{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Environment setup: pin the GPU, import dependencies, and choose the compute device.\n",
    "import os\n",
    "\n",
    "# CUDA_VISIBLE_DEVICES is only honoured if it is set before CUDA is initialised, so it\n",
    "# must be assigned before importing torch or any project module that might touch the GPU.\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
    "\n",
    "# Project-local modules; they presumably provide MNIST_CNN, FLServer and mnist_noniid,\n",
    "# which later cells use unqualified.\n",
    "from MLModel import *\n",
    "from FLModel import *\n",
    "from utils import *\n",
    "\n",
    "from torchvision import datasets, transforms\n",
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "# With only physical GPU 1 visible it is addressed as cuda:0; fall back to CPU otherwise.\n",
    "# To force CPU execution, replace the line below with: device = torch.device(\"cpu\")\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_cnn_mnist(num_users, root=\"~/data/\"):\n",
    "    \"\"\"Load MNIST and split the training set into non-IID client shards.\n",
    "\n",
    "    Pixels are standardised with the mean/std of the *training* images; the same\n",
    "    statistics are applied to the test images so no test information leaks in.\n",
    "\n",
    "    Args:\n",
    "        num_users: number of clients; passed to `mnist_noniid` to build the split.\n",
    "        root: directory where torchvision stores / downloads MNIST.\n",
    "\n",
    "    Returns:\n",
    "        A list of `num_users + 1` `(images, labels)` tuples: one per client, in\n",
    "        client order, followed by the full test set as the last entry. Images are\n",
    "        float tensors of shape (N, 1, 28, 28); labels are converted to float.\n",
    "    \"\"\"\n",
    "    # `.data` / `.targets` are read directly, so a torchvision transform would never run.\n",
    "    train = datasets.MNIST(root=root, train=True, download=True)\n",
    "    test = datasets.MNIST(root=root, train=False, download=True)\n",
    "\n",
    "    train_data = train.data.float().unsqueeze(1)  # add channel dim -> (N, 1, H, W)\n",
    "    train_label = train.targets\n",
    "    mean = train_data.mean()\n",
    "    std = train_data.std()\n",
    "    train_data = (train_data - mean) / std\n",
    "\n",
    "    # Normalise the test set with the training statistics (already float afterwards).\n",
    "    test_data = (test.data.float().unsqueeze(1) - mean) / std\n",
    "    test_label = test.targets\n",
    "\n",
    "    # Split the training set into non-IID client shards; user_dict[i] presumably holds\n",
    "    # the sample indices assigned to client i.\n",
    "    user_dict = mnist_noniid(train_label, num_users)\n",
    "    non_iid = []\n",
    "    for i in range(num_users):\n",
    "        idx = user_dict[i]\n",
    "        non_iid.append((train_data[idx], train_label[idx].float()))\n",
    "\n",
    "    # The test set is appended as the last entry (presumably used for evaluation).\n",
    "    non_iid.append((test_data, test_label.float()))\n",
    "    return non_iid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pipeline:\n",
    "#   1. load_data\n",
    "#   2. generate clients (step 3)\n",
    "#   3. generate aggregator\n",
    "#   4. training\n",
    "client_num = 4  # number of federated clients (one non-IID shard each)\n",
    "\n",
    "# d[:client_num] are the per-client (images, labels) shards; d[-1] is the test set.\n",
    "d = load_cnn_mnist(client_num)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sigma = 1.0771102905273438\n"
     ]
    }
   ],
   "source": [
    "# --- FL model parameters ---\n",
    "import warnings\n",
    "\n",
    "# NOTE(review): this silences *all* warnings for the rest of the kernel session,\n",
    "# including deprecation warnings from torch / torchvision.\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "lr = 0.15  # learning rate\n",
    "\n",
    "fl_param = {\n",
    "    'output_size': 10,          # number of classes (MNIST digits 0-9)\n",
    "    'client_num': client_num,   # number of clients / data shards\n",
    "    'model': MNIST_CNN,         # model class, passed uninstantiated\n",
    "    'data': d,                  # client shards followed by the test set\n",
    "    'lr': lr,\n",
    "    'E': 500,                   # presumably local iterations per round -- confirm in FLModel\n",
    "    'C': 1,                     # presumably fraction of clients per round -- confirm\n",
    "    'eps': 4.0,                 # presumably DP privacy budget epsilon\n",
    "    'delta': 1e-5,              # presumably DP privacy parameter delta\n",
    "    'q': 0.01,                  # presumably sampling rate -- confirm in FLModel\n",
    "    'clip': 0.1,                # presumably gradient clipping bound\n",
    "    'tot_T': 10,                # number of global rounds run by the training cell\n",
    "    'batch_size': 128,\n",
    "    'device': device,\n",
    "}\n",
    "\n",
    "# Constructing the server prints \"sigma = ...\" (presumably the derived DP noise scale).\n",
    "fl_entity = FLServer(fl_param)\n",
    "fl_entity = fl_entity.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "global epochs = 1, acc = 0.5342  Time taken: 302.74s\n",
      "global epochs = 2, acc = 0.7952  Time taken: 606.97s\n",
      "global epochs = 3, acc = 0.8969  Time taken: 910.56s\n",
      "global epochs = 4, acc = 0.9250  Time taken: 1211.34s\n",
      "global epochs = 5, acc = 0.9406  Time taken: 1516.22s\n",
      "global epochs = 6, acc = 0.9513  Time taken: 1818.05s\n",
      "global epochs = 7, acc = 0.9544  Time taken: 2125.24s\n",
      "global epochs = 8, acc = 0.9579  Time taken: 2427.08s\n",
      "global epochs = 9, acc = 0.9615  Time taken: 2741.97s\n",
      "global epochs = 10, acc = 0.9633  Time taken: 3044.31s\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "\n",
    "# Run fl_param['tot_T'] global rounds; each global_update() returns the accuracy printed below.\n",
    "# perf_counter() is monotonic, so elapsed times are unaffected by system clock changes.\n",
    "# The reported time is cumulative since the first round started.\n",
    "acc = []\n",
    "start_time = time.perf_counter()\n",
    "for t in range(fl_param['tot_T']):\n",
    "    acc.append(fl_entity.global_update())\n",
    "    elapsed = time.perf_counter() - start_time\n",
    "    print(f\"global epochs = {t + 1:d}, acc = {acc[-1]:.4f}  Time taken: {elapsed:.2f}s\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Note: optimizer is SGD (mnt=0.9; presumably momentum = 0.9 -- confirm against FLModel)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
